{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "intrinsic-style-transfer",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true,
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python2",
      "display_name": "Python 2"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/reiinakano/neural-painters/blob/master/notebooks/intrinsic_style_transfer.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "metadata": {
        "id": "s7hqL4XA2wMB",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Install dependencies"
      ]
    },
    {
      "metadata": {
        "id": "hCCErs452j3I",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!pip install ipdb tqdm cloudpickle matplotlib lucid PyDrive"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "5DQcaHeG3snB",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Download checkpoint files for painters"
      ]
    },
    {
      "metadata": {
        "id": "SoGPwL-26vdb",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Download the VAE painter checkpoint files. `-p` keeps mkdir from failing when this cell is re-run.\n",
        "!mkdir -p tf_vae\n",
        "!wget -O tf_vae/vae-300000.index 'https://docs.google.com/uc?export=download&id=1ulHdDxebH46m_0ZoLa2Wsz_6vStYqJQm'\n",
        "!wget -O tf_vae/vae-300000.meta 'https://docs.google.com/uc?export=download&id=1nHN_i7Ro9g0lP4y_YQCvIWrOVX1I3CJa'\n",
        "!wget -O tf_vae/vae-300000.data-00000-of-00001 'https://docs.google.com/uc?export=download&id=18rAJcUJwFJOAcjzsabtqK12udsHMZkVk'\n",
        "!wget -O tf_vae/checkpoint 'https://docs.google.com/uc?export=download&id=18U4qMNBdyvEk-Y-Mr3MNPEHSHxhcO9hn'\n",
        "\n",
        "# Download the GAN painter checkpoint files.\n",
        "!mkdir -p tf_gan3\n",
        "!wget -O tf_gan3/gan-571445.meta 'https://docs.google.com/uc?export=download&id=15kEG1Tiu2FUg5SILVt_9yOsSd3QHwVGA'\n",
        "!wget -O tf_gan3/gan-571445.index 'https://docs.google.com/uc?export=download&id=11uyFbQsRZoWa9Yq52AFXDXPjPQoGF_ER'\n",
        "!wget -O tf_gan3/gan-571445.data-00000-of-00001 'https://docs.google.com/uc?export=download&id=11cbvz-CH3KvfZEwNQ2OUujfbf6AKNoQa'\n",
        "!wget -O tf_gan3/checkpoint 'https://docs.google.com/uc?export=download&id=1A539u51t0L31Ab1M2uPUV2SsCFsNDQRo'\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "jUreUcFc3pRd",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ZibMP7Qj9MwP",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Imports"
      ]
    },
    {
      "metadata": {
        "id": "AUOmOSbl3yWJ",
        "colab_type": "code",
        "outputId": "fb5b8e79-4f5c-4155-8cdd-ea5641431618",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 125
        }
      },
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow.contrib.layers as tcl\n",
        "\n",
        "from IPython.display import display\n",
        "import moviepy.editor as mpy\n",
        "from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter\n",
        "\n",
        "import lucid.modelzoo.vision_models as models\n",
        "from lucid.misc.io import show, load, save\n",
        "import lucid.optvis.objectives as objectives\n",
        "from lucid.optvis.objectives import wrap_objective\n",
        "import lucid.optvis.param as param\n",
        "import lucid.optvis.render as render\n",
        "import lucid.optvis.transform as transform\n",
        "from lucid.misc.redirected_relu_grad import redirected_relu_grad, redirected_relu6_grad\n",
        "from lucid.misc.gradient_override import gradient_override_map\n",
        "\n",
        "print(tf.__version__)"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Imageio: 'ffmpeg-linux64-v3.3.1' was not found on your computer; downloading it now.\n",
            "Try 1. Download from https://github.com/imageio/imageio-binaries/raw/master/ffmpeg/ffmpeg-linux64-v3.3.1 (43.8 MB)\n",
            "Downloading: 8192/45929032 bytes (0.0%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b3162112/45929032 bytes (6.9%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b6127616/45929032 bytes (13.3%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b9428992/45929032 bytes (20.5%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b12763136/45929032 bytes (27.8%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b15777792/45929032 bytes (34.4%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b19398656/45929032 bytes (42.2%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b22216704/45929032 bytes (48.4%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b25206784/45929032 bytes (54.9%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b27770880/45929032 bytes (60.5%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b31096832/45929032 bytes (67.7%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b34643968/45929032 bytes (75.4%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b37756928/45929032 bytes (82.2%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b40992768/45929032 bytes (89.3%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b44351488/45929032 bytes (96.6%)\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b45929032/45929032 bytes (100.0%)\n",
            "  Done\n",
            "File saved as /root/.imageio/ffmpeg/ffmpeg-linux64-v3.3.1.\n",
            "1.13.1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "fVSKioi74hk4",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "0QtZZ2fu-aLR",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## VAE painter"
      ]
    },
    {
      "metadata": {
        "id": "hpBhJ9jO-azL",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "class ConvVAE2(object):\n",
        "  \"\"\"VAE-based neural painter (inference only): action vectors -> stroke images.\n",
        "\n",
        "  A `predictor` network maps each 12-dim action vector to a latent code of\n",
        "  size `z_size`, and a `decoder` network of transposed convolutions maps that\n",
        "  code to a 3-channel image with sigmoid-bounded values.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, reuse=False, gpu_mode=True, graph=None):\n",
        "    \"\"\"Build the painter graph, then open a session and initialize variables.\n",
        "\n",
        "    Args:\n",
        "      reuse: `reuse` flag for the top-level 'conv_vae' variable scope.\n",
        "      gpu_mode: when False, the graph is built under `tf.device('/cpu:0')`.\n",
        "      graph: `tf.Graph` to build into; a fresh graph is created when None.\n",
        "    \"\"\"\n",
        "    self.z_size = 64\n",
        "    self.reuse = reuse\n",
        "    if not gpu_mode:\n",
        "      with tf.device('/cpu:0'):\n",
        "        tf.logging.info('conv_vae using cpu.')\n",
        "        self._build_graph(graph)\n",
        "    else:\n",
        "      tf.logging.info('conv_vae using gpu.')\n",
        "      self._build_graph(graph)\n",
        "    self._init_session()\n",
        "\n",
        "  def build_decoder(self, z, reuse=False):\n",
        "    \"\"\"Decode latent codes `z` into images.\n",
        "\n",
        "    The dense output is reshaped to 1x1x1024 and upsampled by four stride-2\n",
        "    transposed convolutions; the last one has 3 channels and a sigmoid.\n",
        "    \"\"\"\n",
        "    with tf.variable_scope('decoder', reuse=reuse):\n",
        "      h = tf.layers.dense(z, 4*256, name=\"fc\")\n",
        "      h = tf.reshape(h, [-1, 1, 1, 4*256])\n",
        "      h = tf.layers.conv2d_transpose(h, 128, 5, strides=2, activation=tf.nn.relu, name=\"deconv1\")\n",
        "      h = tf.layers.conv2d_transpose(h, 64, 5, strides=2, activation=tf.nn.relu, name=\"deconv2\")\n",
        "      h = tf.layers.conv2d_transpose(h, 32, 6, strides=2, activation=tf.nn.relu, name=\"deconv3\")\n",
        "      return tf.layers.conv2d_transpose(h, 3, 6, strides=2, activation=tf.nn.sigmoid, name=\"deconv4\")\n",
        "\n",
        "  def build_predictor(self, actions, reuse=False, is_training=False):\n",
        "    \"\"\"Map action vectors to latent codes of size `self.z_size`.\n",
        "\n",
        "    Three leaky-ReLU dense layers, each followed by batch normalization, then\n",
        "    a linear projection. `is_training` is forwarded to the batch-norm layers.\n",
        "    \"\"\"\n",
        "    with tf.variable_scope('predictor', reuse=reuse):\n",
        "      h = tf.layers.dense(actions, 256, activation=tf.nn.leaky_relu, name=\"fc1\")\n",
        "      h = tf.layers.batch_normalization(h, training=is_training, name=\"bn1\")\n",
        "      h = tf.layers.dense(h, 64, activation=tf.nn.leaky_relu, name=\"fc2\")\n",
        "      h = tf.layers.batch_normalization(h, training=is_training, name=\"bn2\")\n",
        "      h = tf.layers.dense(h, 64, activation=tf.nn.leaky_relu, name=\"fc3\")\n",
        "      h = tf.layers.batch_normalization(h, training=is_training, name=\"bn3\")\n",
        "      return tf.layers.dense(h, self.z_size, name='fc4')\n",
        "\n",
        "  def _build_graph(self, graph):\n",
        "    \"\"\"Create the `actions` placeholder and the `predicted_z` / `predicted_y` ops in `self.g`.\"\"\"\n",
        "    if graph is None:\n",
        "      self.g = tf.Graph()\n",
        "    else:\n",
        "      self.g = graph\n",
        "    with self.g.as_default(), tf.variable_scope('conv_vae', reuse=self.reuse):\n",
        "\n",
        "      #### predicting part\n",
        "      self.actions = tf.placeholder(tf.float32, shape=[None, 12])\n",
        "      self.predicted_z = self.build_predictor(self.actions, is_training=False)\n",
        "      self.predicted_y = self.build_decoder(self.predicted_z)\n",
        "\n",
        "      # initialize vars\n",
        "      self.init = tf.global_variables_initializer()\n",
        "\n",
        "  def generate_stroke_graph(self, actions):\n",
        "    \"\"\"Build stroke-image ops for an arbitrary `actions` tensor, reusing the painter weights.\n",
        "\n",
        "    Args:\n",
        "      actions: float tensor whose last dimension is 12 (same layout as `self.actions`).\n",
        "\n",
        "    Returns:\n",
        "      The decoder output tensor for `actions`.\n",
        "    \"\"\"\n",
        "    # Make the painter's graph the default *before* opening the variable scope\n",
        "    # (same order as _build_graph). Variable-scope state is looked up on the\n",
        "    # current default graph, so opening the scope first only reuses the\n",
        "    # variables when the caller has already made self.g the default graph.\n",
        "    with self.g.as_default(), tf.variable_scope('conv_vae', reuse=True):\n",
        "      z = self.build_predictor(actions, reuse=True, is_training=False)\n",
        "      return self.build_decoder(z, reuse=True)\n",
        "\n",
        "  def _init_session(self):\n",
        "    \"\"\"Launch TensorFlow session and initialize variables\"\"\"\n",
        "    self.sess = tf.Session(graph=self.g)\n",
        "    self.sess.run(self.init)\n",
        "\n",
        "  def close_sess(self):\n",
        "    \"\"\" Close TensorFlow session \"\"\"\n",
        "    self.sess.close()\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "nBmf0lFwBij4",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "cT-NXDh7BjB0",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## GAN Painter"
      ]
    },
    {
      "metadata": {
        "id": "TCch9ykeB3dO",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def relu_batch_norm(x):\n",
        "    \"\"\"Batch-normalize `x` (with `updates_collections=None`) and apply ReLU to the result.\"\"\"\n",
        "    normalized = tf.contrib.layers.batch_norm(x, updates_collections=None)\n",
        "    return tf.nn.relu(normalized)\n",
        "\n",
        "class GeneratorConditional(object):\n",
        "    \"\"\"DCGAN-style conditional generator used as the GAN painter.\n",
        "\n",
        "    Calling the instance builds ops that map a batch of condition vectors to\n",
        "    images: one fully connected layer reshaped to 4x4 feature maps, followed\n",
        "    by four stride-2 transposed convolutions, the last with 3 channels and a\n",
        "    sigmoid.\n",
        "\n",
        "    Args:\n",
        "      divisor: integer by which the layer widths (1024, 512, 256, 128) are divided.\n",
        "      add_noise: if True, 10 uniform-random values are appended to each condition vector.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, divisor=1, add_noise=False):\n",
        "        self.x_dim = 64 * 64 * 3\n",
        "        self.divisor = divisor\n",
        "        self.name = 'lsun/dcgan/g_net'\n",
        "        self.add_noise = add_noise\n",
        "\n",
        "    def __call__(self, conditions, is_training):\n",
        "        \"\"\"Build the generator ops for `conditions` and return the image tensor.\n",
        "\n",
        "        Args:\n",
        "          conditions: 2-D float tensor, one condition vector per row.\n",
        "          is_training: forwarded to every `tcl.batch_norm` call via `arg_scope`.\n",
        "        \"\"\"\n",
        "        with tf.contrib.framework.arg_scope([tcl.batch_norm],\n",
        "                                            is_training=is_training):\n",
        "            with tf.variable_scope(self.name):\n",
        "                bs = tf.shape(conditions)[0]\n",
        "                if self.add_noise:\n",
        "                    conditions = tf.concat([conditions, tf.random.uniform([bs, 10])], axis=1)\n",
        "                # Floor division keeps every layer width an int on Python 3 as well,\n",
        "                # and deriving the dense width from the same base keeps it consistent\n",
        "                # with the reshape below even when divisor does not divide 1024.\n",
        "                base_channels = 1024 // self.divisor\n",
        "                fc = tcl.fully_connected(conditions, 4 * 4 * base_channels, activation_fn=tf.identity)\n",
        "                conv1 = tf.reshape(fc, tf.stack([bs, 4, 4, base_channels]))\n",
        "                conv1 = relu_batch_norm(conv1)\n",
        "                conv2 = tcl.conv2d_transpose(\n",
        "                    conv1, 512 // self.divisor, [4, 4], [2, 2],\n",
        "                    weights_initializer=tf.random_normal_initializer(stddev=0.02),\n",
        "                    activation_fn=relu_batch_norm\n",
        "                )\n",
        "                conv3 = tcl.conv2d_transpose(\n",
        "                    conv2, 256 // self.divisor, [4, 4], [2, 2],\n",
        "                    weights_initializer=tf.random_normal_initializer(stddev=0.02),\n",
        "                    activation_fn=relu_batch_norm\n",
        "                )\n",
        "                conv4 = tcl.conv2d_transpose(\n",
        "                    conv3, 128 // self.divisor, [4, 4], [2, 2],\n",
        "                    weights_initializer=tf.random_normal_initializer(stddev=0.02),\n",
        "                    activation_fn=relu_batch_norm\n",
        "                )\n",
        "                conv5 = tcl.conv2d_transpose(\n",
        "                    conv4, 3, [4, 4], [2, 2],\n",
        "                    weights_initializer=tf.random_normal_initializer(stddev=0.02),\n",
        "                    activation_fn=tf.sigmoid)\n",
        "                return conv5\n",
        "\n",
        "    @property\n",
        "    def vars(self):\n",
        "        \"\"\"Global variables of the current default graph whose name contains `self.name`.\"\"\"\n",
        "        return [var for var in tf.global_variables() if self.name in var.name]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "L1nDzmbkB9Rf",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "class ConvGAN(object):\n",
        "  \"\"\"GAN-based neural painter (inference only): action vectors -> stroke images.\n",
        "\n",
        "  Wraps a `GeneratorConditional` (divisor=4, no noise input) under the\n",
        "  'conv_gan' variable scope and owns a session on the graph it builds into.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, reuse=False, gpu_mode=True, graph=None):\n",
        "    \"\"\"Build the painter graph, then open a session and initialize variables.\n",
        "\n",
        "    Args:\n",
        "      reuse: `reuse` flag for the top-level 'conv_gan' variable scope.\n",
        "      gpu_mode: when False, the graph is built under `tf.device('/cpu:0')`.\n",
        "      graph: `tf.Graph` to build into; a fresh graph is created when None.\n",
        "    \"\"\"\n",
        "    self.reuse = reuse\n",
        "    self.g_net = GeneratorConditional(divisor=4, add_noise=False)\n",
        "\n",
        "    if not gpu_mode:\n",
        "      with tf.device('/cpu:0'):\n",
        "        tf.logging.info('conv_gan using cpu.')\n",
        "        self._build_graph(graph)\n",
        "    else:\n",
        "      tf.logging.info('conv_gan using gpu.')\n",
        "      self._build_graph(graph)\n",
        "    self._init_session()\n",
        "\n",
        "  def _build_graph(self, graph):\n",
        "    \"\"\"Create the `actions` placeholder and the generator output `y` in `self.g`.\"\"\"\n",
        "    if graph is None:\n",
        "      self.g = tf.Graph()\n",
        "    else:\n",
        "      self.g = graph\n",
        "\n",
        "    with self.g.as_default(), tf.variable_scope('conv_gan', reuse=self.reuse):\n",
        "      self.actions = tf.placeholder(tf.float32, shape=[None, 12])\n",
        "      self.y = self.g_net(self.actions, is_training=False)\n",
        "      self.init = tf.global_variables_initializer()\n",
        "\n",
        "  def generate_stroke_graph(self, actions):\n",
        "    \"\"\"Build stroke-image ops for an arbitrary `actions` tensor, reusing the generator weights.\n",
        "\n",
        "    Args:\n",
        "      actions: float tensor whose last dimension is 12 (same layout as `self.actions`).\n",
        "\n",
        "    Returns:\n",
        "      The generator output tensor for `actions`.\n",
        "    \"\"\"\n",
        "    # Make the painter's graph the default *before* opening the variable scope\n",
        "    # (same order as _build_graph). Variable-scope state is looked up on the\n",
        "    # current default graph, so opening the scope first only reuses the\n",
        "    # variables when the caller has already made self.g the default graph.\n",
        "    with self.g.as_default(), tf.variable_scope('conv_gan', reuse=True):\n",
        "      return self.g_net(actions, is_training=False)\n",
        "\n",
        "  def _init_session(self):\n",
        "    \"\"\"Launch TensorFlow session and initialize variables\"\"\"\n",
        "    self.sess = tf.Session(graph=self.g)\n",
        "    self.sess.run(self.init)\n",
        "\n",
        "  def close_sess(self):\n",
        "    \"\"\" Close TensorFlow session \"\"\"\n",
        "    self.sess.close()\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "oKDyMa-oFI2p",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "lDpuowtlFMNj",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Construct the Lucid graph"
      ]
    },
    {
      "metadata": {
        "id": "797y3uqQFNqz",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# these constants help remember which image is at which batch dimension\n",
        "TRANSFER_INDEX = 0\n",
        "CONTENT_INDEX = 1\n",
        "\n",
        "content_layers = [\n",
        "  'mixed3b',\n",
        "]\n",
        "\n",
        "def mean_L1(a, b):\n",
        "  \"\"\"Return the mean of the element-wise absolute difference between `a` and `b`.\"\"\"\n",
        "  abs_diff = tf.abs(a-b)\n",
        "  return tf.reduce_mean(abs_diff)\n",
        "\n",
        "@wrap_objective\n",
        "def activation_difference(layer_names, activation_loss_f=mean_L1, transform_f=None, difference_to=CONTENT_INDEX):\n",
        "  \"\"\"Objective: summed per-layer loss between two images of the same batch.\n",
        "\n",
        "  For every layer in `layer_names`, the activations at batch index\n",
        "  `difference_to` are compared with those at `TRANSFER_INDEX` using\n",
        "  `activation_loss_f`; `transform_f`, when given, is applied to both sides\n",
        "  first. The per-layer losses are added together.\n",
        "  \"\"\"\n",
        "  def inner(T):\n",
        "    def activations_at(batch_index):\n",
        "      # activations of one batch entry for every requested layer\n",
        "      acts = [T(name)[batch_index] for name in layer_names]\n",
        "      if transform_f is None:\n",
        "        return acts\n",
        "      return [transform_f(act) for act in acts]\n",
        "\n",
        "    # the (constant) activations of the image we are computing the difference to\n",
        "    reference_acts = activations_at(difference_to)\n",
        "    # the activations of the optimized image, which change during optimization\n",
        "    optimized_acts = activations_at(TRANSFER_INDEX)\n",
        "\n",
        "    # the supplied loss function gives one loss per layer\n",
        "    per_layer = [activation_loss_f(ref, opt) for ref, opt in zip(reference_acts, optimized_acts)]\n",
        "    return tf.add_n(per_layer)\n",
        "\n",
        "  return inner\n",
        "\n",
        "def import_model(model, t_image, t_image_raw, scope=\"import\"):\n",
        "  \"\"\"Import `model` on top of `t_image` and return a tensor-lookup function `T`.\n",
        "\n",
        "  `T(\"input\")` gives `t_image_raw`, `T(\"labels\")` gives `model.labels`, and any\n",
        "  other name is resolved inside `scope` in `t_image`'s graph (output 0 unless\n",
        "  the name already carries a \":<index>\" suffix).\n",
        "  \"\"\"\n",
        "  model.import_graph(t_image, scope=scope, forget_xy_shape=True)\n",
        "\n",
        "  def T(layer):\n",
        "    if layer == \"input\":\n",
        "      return t_image_raw\n",
        "    if layer == \"labels\":\n",
        "      return model.labels\n",
        "    name_format = \"%s/%s\" if \":\" in layer else \"%s/%s:0\"\n",
        "    return t_image.graph.get_tensor_by_name(name_format % (scope,layer))\n",
        "\n",
        "  return T\n",
        "\n",
        "class LucidGraph(object):\n",
        "  \"\"\"Optimizes brush-stroke actions so the painted canvas matches a content image.\n",
        "\n",
        "  The canvas is a `repeat` x `repeat` grid of 64-px painter tiles that overlap\n",
        "  by `overlap_px` pixels. A trainable `action_vars` variable holds one 12-dim\n",
        "  action per stroke; strokes are rendered by the painter (GAN or VAE) and\n",
        "  composited onto the canvas one at a time inside a `tf.while_loop`. The loss\n",
        "  is an activation difference between the canvas and the content image in\n",
        "  InceptionV1, and only the action variable is handed to the optimizer.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, content, overlap_px=10, repeat=2, num_strokes=5, painter_type=\"GAN\", connected=True, alternate=True, bw=False, learning_rate=0.1, gpu_mode=True, graph=None):\n",
        "    \"\"\"Build the full optimization graph and start a session.\n",
        "\n",
        "    Args:\n",
        "      content: content image array; it is resized to the square canvas size.\n",
        "      overlap_px: overlap in pixels between neighbouring 64-px tiles.\n",
        "      repeat: number of tiles along each side of the canvas.\n",
        "      num_strokes: number of strokes per tile.\n",
        "      painter_type: \"GAN\" or \"VAE\".\n",
        "      connected: if True, every stroke after the first takes its last three\n",
        "        action components from the previous stroke (see `use_previous_entrypoint`).\n",
        "      alternate: if True, stroke i is painted on tile `i mod repeat**2`;\n",
        "        otherwise each tile receives `num_strokes` consecutive strokes.\n",
        "      bw: if True, action components 6:9 (used as the stroke colour) are\n",
        "        replaced by constant zeros before the sigmoid.\n",
        "      learning_rate: learning rate of the Adam optimizer.\n",
        "      gpu_mode: when False, the graph is built under `tf.device('/cpu:0')`.\n",
        "      graph: `tf.Graph` to build into; a fresh graph is created when None.\n",
        "\n",
        "    Raises:\n",
        "      ValueError: if `painter_type` is neither \"GAN\" nor \"VAE\".\n",
        "    \"\"\"\n",
        "    # Fail fast: an unknown painter type would otherwise only surface later as\n",
        "    # an AttributeError on `self.painter`, after the Inception graph is loaded.\n",
        "    if painter_type not in (\"GAN\", \"VAE\"):\n",
        "      raise ValueError('painter_type must be \"GAN\" or \"VAE\", got %r' % (painter_type,))\n",
        "\n",
        "    self.overlap_px = overlap_px\n",
        "    self.repeat = repeat\n",
        "    # Side length of the canvas: `repeat` tiles of 64 px minus the overlaps.\n",
        "    self.full_size = 64*repeat - overlap_px*(repeat - 1)\n",
        "    self.unrepeated_num_strokes = num_strokes\n",
        "    # Total stroke count: `num_strokes` for each of the repeat**2 tiles.\n",
        "    self.num_strokes = num_strokes * self.repeat**2\n",
        "    self.painter_type = painter_type\n",
        "    self.connected = connected\n",
        "    self.alternate = alternate\n",
        "    self.bw = bw\n",
        "    print('full_size', self.full_size, 'max_seq_len', self.num_strokes)\n",
        "\n",
        "    self.content = content\n",
        "\n",
        "    self.inception_v1 = models.InceptionV1()\n",
        "    self.inception_v1.load_graphdef()\n",
        "\n",
        "    # Random augmentations applied to the image batch before it enters Inception.\n",
        "    transforms = [\n",
        "      transform.pad(12, mode='constant', constant_value=.5),\n",
        "      transform.jitter(8),\n",
        "      transform.random_scale([1 + (i-5)/50. for i in range(11)]),\n",
        "      transform.random_rotate(list(range(-5, 5)) + 5*[0]),\n",
        "      transform.jitter(4),\n",
        "    ]\n",
        "\n",
        "    self.transform_f = render.make_transform_f(transforms)\n",
        "\n",
        "    self.optim = render.make_optimizer(tf.train.AdamOptimizer(learning_rate), [])\n",
        "\n",
        "    self.gpu_mode = gpu_mode\n",
        "    if not gpu_mode:\n",
        "      with tf.device('/cpu:0'):\n",
        "        tf.logging.info('Model using cpu.')\n",
        "        self._build_graph(graph)\n",
        "    else:\n",
        "      #tf.logging.info('Model using gpu.')\n",
        "      self._build_graph(graph)\n",
        "    self._init_session()\n",
        "\n",
        "  def _build_graph(self, graph):\n",
        "    \"\"\"Build painter, stroke-compositing loop, content loss and train op in `self.g`.\"\"\"\n",
        "    if graph is None:\n",
        "      self.g = tf.Graph()\n",
        "    else:\n",
        "      self.g = graph\n",
        "\n",
        "    # Set up graphs of VAE or GAN\n",
        "    if self.painter_type == \"GAN\":\n",
        "      self.painter = ConvGAN(\n",
        "              reuse=False,\n",
        "              gpu_mode=self.gpu_mode,\n",
        "              graph=self.g)\n",
        "    elif self.painter_type == \"VAE\":\n",
        "      self.painter = ConvVAE2(\n",
        "              reuse=False,\n",
        "              gpu_mode=self.gpu_mode,\n",
        "              graph=self.g)\n",
        "    else:\n",
        "      raise ValueError('painter_type must be \"GAN\" or \"VAE\", got %r' % (self.painter_type,))\n",
        "    # The painter opens its own session when constructed; this class runs\n",
        "    # everything in its own session instead, so that one is closed right away.\n",
        "    self.painter.close_sess()\n",
        "\n",
        "    with self.g.as_default():\n",
        "      print('GLOBAL VARS', tf.global_variables())\n",
        "\n",
        "    with self.g.as_default():\n",
        "      batch_size = 1\n",
        "      self.actions = tf.get_variable(\"action_vars\", [batch_size, self.num_strokes, 12],\n",
        "                                     #initializer=tf.initializers.random_normal()\n",
        "                                     initializer=tf.initializers.random_uniform()\n",
        "                                    )\n",
        "      if self.bw:\n",
        "        # Swap action components 6:9 for constant zeros so they receive no gradient.\n",
        "        actions2 = tf.concat([self.actions[:, :, :6], tf.zeros([batch_size, self.num_strokes, 3]), self.actions[:, :, 9:]], axis=2)\n",
        "      else:\n",
        "        actions2 = self.actions\n",
        "\n",
        "      # Lets callers overwrite the action variable with externally supplied values.\n",
        "      self.actions_assign_ph = tf.placeholder(dtype=tf.float32)\n",
        "      self.actions_assign_op = tf.assign(self.actions, self.actions_assign_ph)\n",
        "\n",
        "      # Prepare loop vars for rnn loop: a white canvas, a TensorArray that\n",
        "      # records the canvas after every stroke, and the stroke counter.\n",
        "      canvas_state = tf.ones(shape=[batch_size, self.full_size, self.full_size, 3], dtype=tf.float32)\n",
        "      i = tf.constant(0)\n",
        "      initial_canvas_ta = tf.TensorArray(dtype=tf.float32, size=self.num_strokes)\n",
        "      loop_vars = (\n",
        "          canvas_state,\n",
        "          initial_canvas_ta, i)\n",
        "\n",
        "      # condition for continuation\n",
        "      def cond(cs, c_ta, i):\n",
        "        return tf.less(i, self.num_strokes)\n",
        "\n",
        "      # run one state of rnn cell\n",
        "      def body(cs, c_ta, i):\n",
        "        # Squash the raw action values into (0, 1) before they reach the painter.\n",
        "        trimmed_actions = tf.sigmoid(actions2)\n",
        "\n",
        "        print(trimmed_actions.get_shape())\n",
        "\n",
        "        def use_whole_action():\n",
        "          return trimmed_actions[:, i, :12]\n",
        "\n",
        "        def use_previous_entrypoint():\n",
        "          # start x and y are previous end x and y\n",
        "          # start pressure is previous pressure\n",
        "          return tf.concat([trimmed_actions[:, i, :9], trimmed_actions[:, i-1, 4:6], trimmed_actions[:, i-1, 0:1]], axis=1)\n",
        "\n",
        "        if self.connected:\n",
        "          inp = tf.cond(tf.equal(i, 0), true_fn=use_whole_action, false_fn=use_previous_entrypoint)\n",
        "        else:\n",
        "          inp = use_whole_action()\n",
        "        inp = tf.reshape(inp, [-1, 12])\n",
        "\n",
        "        print(inp.get_shape())\n",
        "\n",
        "        stroke_tile = self.painter.generate_stroke_graph(inp)\n",
        "\n",
        "        # One (predicate, pad-fn) pair per tile: the stroke is padded with\n",
        "        # white (1) so that it lands on tile (a, b) of the full canvas.\n",
        "        tile_stride = 64 - self.overlap_px\n",
        "        num_tiles = int(self.repeat**2)\n",
        "        cases = []\n",
        "        ctr = 0\n",
        "        for a in range(self.repeat):\n",
        "          for b in range(self.repeat):\n",
        "            paddings = [[0, 0], [tile_stride*a, tile_stride*(self.repeat-1-a)], [tile_stride*b, tile_stride*(self.repeat-1-b)], [0, 0]]\n",
        "            print([num_tiles, ctr])\n",
        "            print(paddings)\n",
        "            if self.alternate:\n",
        "              tile_selected = tf.equal(tf.floormod(i, num_tiles), ctr)\n",
        "            else:\n",
        "              # Predicates overlap here; tf.case takes the first true one in list order.\n",
        "              tile_selected = tf.less(i, self.unrepeated_num_strokes*(ctr+1))\n",
        "            cases.append(\n",
        "              (\n",
        "                  tile_selected,\n",
        "                  # Bind this tile's paddings now; a plain closure would see the last one.\n",
        "                  lambda paddings=paddings: tf.pad(stroke_tile, paddings, constant_values=1)\n",
        "              )\n",
        "            )\n",
        "            ctr += 1\n",
        "\n",
        "        print(cases)\n",
        "        decoded_stroke = tf.case(cases)\n",
        "\n",
        "        # Blend weight per pixel: 1 minus the channel mean, scaled so its maximum is 1.\n",
        "        darkness_mask = tf.reduce_mean(decoded_stroke, axis=3)\n",
        "        darkness_mask = 1 - tf.reshape(darkness_mask, [batch_size, self.full_size, self.full_size, 1])\n",
        "        darkness_mask = darkness_mask / tf.reduce_max(darkness_mask)\n",
        "\n",
        "        # Pixels that are exactly 1 keep their value; all others take the action colour.\n",
        "        color_action = trimmed_actions[:, i, 6:9]\n",
        "        color_action = tf.reshape(color_action, [batch_size, 1, 1, 3])\n",
        "        color_action = tf.tile(color_action, [1, self.full_size, self.full_size, 1])\n",
        "        stroke_whitespace = tf.equal(decoded_stroke, 1.)\n",
        "        maxed_stroke = tf.where(stroke_whitespace, decoded_stroke, color_action)\n",
        "\n",
        "        cs = (darkness_mask)*maxed_stroke + (1-darkness_mask)*cs\n",
        "        c_ta = c_ta.write(i, cs)\n",
        "\n",
        "        i = tf.add(i, 1)\n",
        "        return (cs, c_ta, i)\n",
        "\n",
        "      final_canvas_state, final_canvas_ta, _ = tf.while_loop(cond, body, loop_vars, swap_memory=True)\n",
        "      self.intermediate_canvases = final_canvas_ta.stack()\n",
        "\n",
        "      # Batch of two images: the painted canvas (TRANSFER_INDEX) and the resized content (CONTENT_INDEX).\n",
        "      content_input = tf.image.resize_images(np.expand_dims(self.content, 0), [self.full_size, self.full_size])\n",
        "      final_canvas_state = tf.stack([final_canvas_state[0], content_input[0]])\n",
        "      print(final_canvas_state.shape)\n",
        "      self.final_canvas_state = final_canvas_state\n",
        "      self.resized_final = tf.image.resize_images(final_canvas_state, [256, 256])\n",
        "\n",
        "      #For visualization\n",
        "      self.content_style_vis = final_canvas_state[1:]\n",
        "\n",
        "      global_step = tf.train.get_or_create_global_step()\n",
        "\n",
        "      with gradient_override_map({'Relu': redirected_relu_grad,\n",
        "                                  'Relu6': redirected_relu6_grad}):\n",
        "        self.T = render.import_model(self.inception_v1, self.transform_f(final_canvas_state), final_canvas_state)\n",
        "\n",
        "      content_obj = 100 * activation_difference(content_layers, difference_to=CONTENT_INDEX)\n",
        "      content_obj.description = \"Content Loss\"\n",
        "\n",
        "      self.loss = content_obj(self.T)\n",
        "\n",
        "      # Only the stroke actions are passed to the optimizer.\n",
        "      self.vis_op = self.optim.minimize(self.loss, global_step=global_step, var_list=[self.actions])\n",
        "\n",
        "      # initialize vars\n",
        "      self.init = tf.global_variables_initializer()\n",
        "\n",
        "      print('TRAINABLE', tf.trainable_variables())\n",
        "\n",
        "  def train(self, thresholds=range(0, 5000, 30)):\n",
        "    \"\"\"Run the optimization and show the canvas next to the content image.\n",
        "\n",
        "    Args:\n",
        "      thresholds: step indices at which the loss is printed and the images are\n",
        "        shown; optimization runs for `max(thresholds) + 1` steps. A\n",
        "        KeyboardInterrupt stops early and shows the current canvas.\n",
        "    \"\"\"\n",
        "    self.images = []\n",
        "    print(self.sess.run(self.actions))\n",
        "    vis = self.sess.run(self.final_canvas_state)\n",
        "    show(np.hstack(vis[:2]))\n",
        "    # Set membership keeps the per-step check cheap regardless of how many thresholds there are.\n",
        "    report_steps = set(thresholds)\n",
        "    last_step = max(report_steps)\n",
        "    try:\n",
        "      for i in range(last_step+1):\n",
        "        content_loss_, _ = self.sess.run([self.loss, self.vis_op])\n",
        "        if i in report_steps:\n",
        "          vis = self.sess.run(self.resized_final)\n",
        "          print(i, content_loss_,)\n",
        "          show(np.hstack(vis[:2]))\n",
        "    except KeyboardInterrupt:\n",
        "      vis = self.sess.run(self.final_canvas_state)\n",
        "      show(np.hstack(vis[:2]))\n",
        "\n",
        "  def _init_session(self):\n",
        "    \"\"\"Open a session on `self.g` and run the variable initializer.\"\"\"\n",
        "    self.sess = tf.Session(graph=self.g)\n",
        "    self.sess.run(self.init)\n",
        "\n",
        "  def close_sess(self):\n",
        "    \"\"\"Close the session.\"\"\"\n",
        "    self.sess.close()\n",
        "\n",
        "  def load_painter_checkpoint(self, checkpoint_path='tf_conv_vae', actual_path=None):\n",
        "    \"\"\"Restore the painter's variables from a checkpoint into this object's session.\n",
        "\n",
        "    Args:\n",
        "      checkpoint_path: directory containing a `checkpoint` state file; used\n",
        "        only when `actual_path` is None.\n",
        "      actual_path: explicit checkpoint prefix to restore from.\n",
        "\n",
        "    Raises:\n",
        "      ValueError: if `self.painter_type` is neither \"GAN\" nor \"VAE\".\n",
        "      IOError: if `actual_path` is None and no checkpoint is found in `checkpoint_path`.\n",
        "    \"\"\"\n",
        "    if self.painter_type == \"VAE\":\n",
        "      pth = 'conv_vae'\n",
        "    elif self.painter_type == \"GAN\":\n",
        "      pth = 'conv_gan'\n",
        "    else:\n",
        "      raise ValueError('painter_type must be \"GAN\" or \"VAE\", got %r' % (self.painter_type,))\n",
        "\n",
        "    # Restore only the variables under the painter's scope.\n",
        "    with self.g.as_default():\n",
        "      saver = tf.train.Saver(tf.global_variables(pth))\n",
        "\n",
        "    if actual_path is None:\n",
        "      ckpt = tf.train.get_checkpoint_state(checkpoint_path)\n",
        "      # get_checkpoint_state returns None when the directory has no checkpoint file.\n",
        "      if ckpt is None or not ckpt.model_checkpoint_path:\n",
        "        raise IOError('No checkpoint found in %r' % (checkpoint_path,))\n",
        "      actual_path = ckpt.model_checkpoint_path\n",
        "    print('loading model', actual_path)\n",
        "    tf.logging.info('Loading model %s.', actual_path)\n",
        "    saver.restore(self.sess, actual_path)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "KjUUZ_-IM-Ph",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Choose parameters"
      ]
    },
    {
      "metadata": {
        "id": "vfRMZhEyNvFC",
        "colab_type": "code",
        "outputId": "cd507acc-cdf6-4883-946e-9098e4b8f514",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 125
        }
      },
      "cell_type": "code",
      "source": [
        "#@title After running this cell manually, it will auto-run if you change the selected value. { run: \"auto\", display-mode: \"form\" }\n",
        "\n",
        "NUMBER_STROKES = 4 #@param {type:\"slider\", min:1, max:10, step:1}\n",
        "#@markdown Number of strokes per section. By default we have 64 sections.\n",
        "PAINTER_MODE = \"GAN\" #@param [\"GAN\", \"VAE\"]\n",
        "#@markdown GAN mode results in strokes that actually look like paintbrush strokes.\n",
        "CONNECTED_STROKES = False #@param {type:\"boolean\"}\n",
        "#@markdown If true, strokes begin at the endpoint of the previous stroke. Otherwise, strokes are independent and can start anywhere.\n",
        "BW = False #@param {type:\"boolean\"}\n",
        "#@markdown Black and white\n",
        "LEARNING_RATE = 0.1 #@param {type: \"number\"}\n",
        "\n",
        "print(\"Number of strokes\", NUMBER_STROKES)\n",
        "print(\"Using {} painter\".format(PAINTER_MODE))\n",
        "print(\"Using connected strokes\", CONNECTED_STROKES)\n",
        "print(\"Grayscale\", BW)\n",
        "print(\"Learning Rate\", LEARNING_RATE)\n",
        "print('--------------------')"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "('Number of strokes', 4)\n",
            "Using GAN painter\n",
            "('Using connected strokes', False)\n",
            "('Grayscale', False)\n",
            "('Learning Rate', 0.1)\n",
            "--------------------\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "JwNK9D6QCEtH",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Input your content image here.\n",
        "\n",
        "The `load` function takes a link or local filepath. Input images will be forced to squares."
      ]
    },
    {
      "metadata": {
        "id": "sH06RYu-8hzy",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Default content image, fetched from a URL.\n",
        "_CONTENT_SOURCE = \"https://storage.googleapis.com/tensorflow-lucid/static/img/notebook-styletransfer-bigben.jpg\"\n",
        "# To use a local file instead, point _CONTENT_SOURCE at its path, e.g.:\n",
        "#_CONTENT_SOURCE = \"local_path.jpg\"\n",
        "\n",
        "_loaded_image = load(_CONTENT_SOURCE)\n",
        "CONTENT_IMAGE = _loaded_image[..., :3]  # Keep the first three channels (removes a transparency channel)\n",
        "\n",
        "show(CONTENT_IMAGE)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "J3MTakNPQ3Du",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Run!"
      ]
    },
    {
      "metadata": {
        "id": "ntvHiCwgHi56",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Checkpoint directory for each supported painter type.\n",
        "_PAINTER_CHECKPOINTS = {\"GAN\": 'tf_gan3', \"VAE\": 'tf_vae'}\n",
        "\n",
        "# Fail fast on an unsupported mode: without this check, no call to\n",
        "# load_painter_checkpoint() would be made and train() would still run.\n",
        "if PAINTER_MODE not in _PAINTER_CHECKPOINTS:\n",
        "  raise ValueError(\"Unknown PAINTER_MODE {!r}; expected one of {}\".format(PAINTER_MODE, sorted(_PAINTER_CHECKPOINTS)))\n",
        "\n",
        "lol = LucidGraph(CONTENT_IMAGE, 32, 8, NUMBER_STROKES, painter_type=PAINTER_MODE, connected=CONNECTED_STROKES, alternate=False, bw=BW, learning_rate=LEARNING_RATE)\n",
        "\n",
        "# Load the painter checkpoint that matches the selected mode, then train.\n",
        "lol.load_painter_checkpoint(_PAINTER_CHECKPOINTS[PAINTER_MODE])\n",
        "lol.train()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Ket_vOx0Tvn5",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Evaluate results"
      ]
    },
    {
      "metadata": {
        "id": "OSaIOM9PH1n8",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Fetch the current value of lol.actions from the session and write it to actions.npy.\n",
        "_current_actions = lol.sess.run(lol.actions)\n",
        "np.save('actions', _current_actions)\n",
        "\n",
        "# Read the saved array back from disk; the evaluation cell feeds _acs into the graph.\n",
        "_acs = np.load('actions.npy')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "V4B5syw5H1lW",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "-vkWTmT7H1jF",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Disabled: rebuild the graph and reload the GAN painter checkpoint before evaluating.\n",
        "# NOTE(review): this disabled call passes an extra `dummy` argument that the\n",
        "# LucidGraph call in the Run cell does not -- confirm the signature before re-enabling.\n",
        "#lol = LucidGraph(CONTENT_IMAGE, dummy, 32, 8, NUMBER_STROKES, painter_type=PAINTER_MODE, connected=CONNECTED_STROKES, bw=BW, alternate=False, learning_rate=LEARNING_RATE)\n",
        "#lol.load_painter_checkpoint('tf_gan3')\n",
        "\n",
        "# Run the graph with the saved actions (_acs) fed in for lol.actions, fetching\n",
        "# lol.intermediate_canvases and lol.content_style_vis for the plotting cell.\n",
        "_int_canvases, _content_style = lol.sess.run([lol.intermediate_canvases, lol.content_style_vis], feed_dict={lol.actions: _acs})"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "NA5Hqbor-UiB",
        "colab_type": "code",
        "outputId": "54efd93b-400c-4d97-be73-db5979889726",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "cell_type": "code",
      "source": [
        "_SIZE = lol.full_size\n",
        "_HOLD_FRAMES = 50  # copies of the last frame appended to the end of each sequence\n",
        "\n",
        "\n",
        "def _frames_for_target(target_idx):\n",
        "  \"\"\"Build the frame stack for one target index: the _content_style image joined to each canvas along axis 2.\"\"\"\n",
        "  # Every second intermediate canvas, indexed at target_idx on axis 1.\n",
        "  # The repeat count of 1 keeps each frame exactly once.\n",
        "  every_other = _int_canvases[::2]\n",
        "  canvases = np.repeat(every_other[:, target_idx, :, :, :], 1, axis=0)\n",
        "\n",
        "  # One copy of the reshaped _content_style image per canvas frame.\n",
        "  reference = np.concatenate(_content_style, axis=1).reshape(1, _SIZE, _SIZE, 3)\n",
        "  references = np.tile(reference, [len(canvases), 1, 1, 1])\n",
        "\n",
        "  print(canvases.shape)\n",
        "\n",
        "  # Reference first, canvas second; only the first _SIZE entries along axis 2 of the reference are used.\n",
        "  joined = np.concatenate([references[:, :, :_SIZE, :], canvases], axis=2)\n",
        "\n",
        "  # Pad the sequence with repeated copies of its final frame.\n",
        "  held_tail = np.tile(joined[-1:, :, :, :], [_HOLD_FRAMES, 1, 1, 1])\n",
        "  return np.concatenate([joined, held_tail], axis=0)\n",
        "\n",
        "\n",
        "_stacked_plots = [_frames_for_target(_target) for _target in range(0, 1)]\n",
        "\n",
        "# Optional GIF export (left disabled):\n",
        "#imageio.mimsave('hello2.gif', np.concatenate(_stacked_plots),'GIF', fps=16)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(128, 288, 288, 3)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "ju7yZITx-Uk_",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "from IPython.display import display\n",
        "\n",
        "import moviepy.editor as mpy\n",
        "from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter\n",
        "\n",
        "def vid(my_frames, fps=10.):\n",
        "  \"\"\"Write my_frames to tmp.mp4 and display the video inline.\n",
        "\n",
        "  Args:\n",
        "    my_frames: indexable sequence of frames. Each frame is multiplied by 255\n",
        "      before being handed to moviepy (assumes values in [0, 1] -- TODO confirm).\n",
        "    fps: frames per second, used both to map clip time to a frame index and\n",
        "      as the frame rate of the written file. Defaults to 10.\n",
        "  \"\"\"\n",
        "\n",
        "  def frame(t):\n",
        "    # Map clip time (seconds) to a frame index, clamped to the last frame.\n",
        "    idx = min(int(t * fps), len(my_frames) - 1)\n",
        "    # Builtin float replaces the np.float alias, which newer NumPy removed.\n",
        "    return (my_frames[idx] * 255).astype(float)\n",
        "\n",
        "  # True division: under Python 2, len(my_frames)/10 floors to whole seconds\n",
        "  # and shortens the clip.\n",
        "  clip = mpy.VideoClip(frame, duration=len(my_frames) / float(fps))\n",
        "  clip.write_videofile('tmp.mp4', fps=fps)\n",
        "  display(mpy.ipython_display('tmp.mp4', height=400, max_duration=100.))\n",
        "vid(np.concatenate(_stacked_plots))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "XOCmFG8E-3NB",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}